# Estimate power plant response function
library(pacman)
p_load(
  here, fst, data.table, fixest, ggplot2, lubridate, 
  dplyr, janitor, purrr, furrr, stringr, AER, tictoc,
  magrittr
)

# Unit info table, keeping only rows whose open_status is 'fully operational'
unit_info_dt = 
  here("Data/electricity-generation/unit-info-dt.fst") |>
  read.fst(as.data.table = TRUE)
unit_info_dt = unit_info_dt[open_status == 'fully operational']

# Loading the electricity generation data
# Keyed on (orispl_code, unitid, utc_time) so fit_unit() can pull one unit's
# rows with a keyed lookup
elec_gen_dt = read.fst(
  here("Data/electricity-generation/elec-gen-dt.fst"),
  as.data.table = TRUE
)
setkeyv(elec_gen_dt, c("orispl_code", "unitid", "utc_time"))

# Formula changes depending on where the plant is located 
# Table with all of the formulas 
# One row per candidate specification: the unit's own region (nerc_adj) plus
# up to three other regions whose excess load enters as a control.
# TRE has a single own-region row, WECC and CAL have two rows each, and
# MRO, NPCC, RFC and SERC have seven rows each.
fml_dt_raw = data.table(
  nerc_adj = rep(
    c('TRE', 'WECC', 'CAL', 'MRO', 'NPCC', 'RFC', 'SERC'),
    times = c(1, 2, 2, 7, 7, 7, 7)
  ),
  control_1 = c(
    NA, 
    NA, 'cal', 
    NA, 'wecc', 
    NA, 'npcc','npcc','npcc','rfc','rfc','serc', 
    NA, 'mro','mro','mro','rfc','rfc','serc', 
    NA, 'mro','mro','mro','npcc','npcc','serc', 
    NA, 'mro','mro','mro','npcc','npcc','rfc'
  ),
  control_2 = c(
    NA, 
    NA, NA, 
    NA, NA, 
    NA, 'rfc','rfc','serc','serc', NA, NA, 
    NA, 'rfc','rfc','serc','serc', NA, NA, 
    NA, 'npcc','npcc','serc','serc', NA, NA, 
    NA, 'npcc','npcc','rfc','rfc', NA, NA
  ),
  control_3 = c(
    NA, 
    NA, NA, 
    NA, NA, 
    NA, 'serc',NA, NA, NA, NA, NA, 
    NA, 'serc',NA, NA, NA, NA, NA, 
    NA, 'serc',NA, NA, NA, NA, NA, 
    NA, 'rfc',NA, NA, NA, NA, NA
  ),
  square = FALSE
)

# Every specification is fit with a quadratic and with a linear own-region
# term; the intercept-only ('constant') model is added once per row that has
# no controls.
fml_dt = rbind(
  fml_dt_raw[, .(nerc_adj, control_1, control_2, control_3, terms = 'quadratic')],
  fml_dt_raw[, .(nerc_adj, control_1, control_2, control_3, terms = 'linear')],
  fml_dt_raw[
    is.na(control_1) & is.na(control_2) & is.na(control_3),
    .(nerc_adj, control_1, control_2, control_3, terms = 'constant')
  ]
)
# Building the formula text. Missing controls first show up as the literal
# term 'excess_load_NA', which is stripped from the string at the end.
fml_dt[, ic_formula := {
  own_term = paste0('excess_load_', tolower(nerc_adj))
  control_terms = paste(
    paste0('excess_load_', control_1),
    paste0('excess_load_', control_2),
    paste0('excess_load_', control_3),
    sep = ' + '
  )
  linear_fml = paste0('gload_mwh ~ ', own_term, ' + ', control_terms)
  quadratic_fml = paste0(
    'gload_mwh ~ ', own_term, ' + ', own_term, '_sq', ' + ', control_terms
  )
  fml_text = fcase(
    terms == 'constant', 'gload_mwh ~ 1',
    terms == 'quadratic', quadratic_fml,
    terms == 'linear', linear_fml
  )
  str_remove_all(fml_text, ' \\+ excess_load_NA')
}]
# Flagging the specifications that use no other region's load
fml_dt[, own_region_only := is.na(control_1) & is.na(control_2) & is.na(control_3)]

# Loading nerc data to see max observed values 
# One row per nerc_adj: the largest excess_load among rows dated 2019
max_nerc_load_dt = 
  here("Data/electricity-generation/clean-load/nerc-load-dt.fst") |>
  read.fst(as.data.table = TRUE) %>%
  .[
    year(utc_time) == 2019 & !is.na(nerc_adj),
    .(max_eload = max(excess_load)),
    keyby = .(nerc_adj)
  ]

#orispl_code_in = 400#elec_gen_dt$orispl_code[1]
#unitid_in = '2'#elec_gen_dt$unitid[1]

# function to fit model for single unit 
#
# For one generating unit this function:
#   1. fits an AER::tobit model (censored at 0 and at the unit's namepcap)
#      for every formula that fml_dt lists for the unit's nerc_adj,
#   2. flags the models that satisfy the sign / monotonicity restrictions and
#      keeps the one with the highest log-likelihood, both overall and among
#      the models that only use own-region load,
#   3. simulates censored fitted generation from each chosen model, and
#   4. writes the coefficients, the chosen models and the hourly fits to fst
#      files under Data/electricity-generation/unit-model-fit/.
#
# Arguments:
#   orispl_code_in, unitid_in: values looked up in the first two key columns
#     of elec_gen_dt (orispl_code, unitid).
# Uses the script-level tables elec_gen_dt, fml_dt, unit_info_dt and
# max_nerc_load_dt.
# Returns the value of the final write.fst() call, or invisible NULL (with a
# warning) when the unit has no rows in elec_gen_dt.
# NOTE(review): the epsilon draws use rnorm() without a seed being set in this
# function, so simulated fits differ between runs unless the caller seeds.
fit_unit = function(orispl_code_in, unitid_in){
  print(paste(orispl_code_in, unitid_in))
  # Coefficient columns: a linear and a squared excess-load term per region
  regions = c('cal','mro','npcc','rfc','serc','tre','wecc')
  linear_cols = paste0('excess_load_', regions)
  square_cols = paste0(linear_cols, '_sq')
  # Interleaved as cal, cal_sq, mro, mro_sq, ... (order used when summing)
  term_cols = c(rbind(linear_cols, square_cols))

  # Helper: intercept + sum(coef * regressor) for every row of data_dt, using
  # the single chosen model in model_dt. NA when no model was chosen.
  linear_predictor = function(model_dt, data_dt){
    if(nrow(model_dt) == 0) return(rep(NA_real_, nrow(data_dt)))
    fit = model_dt$intercept[1]
    for(col in term_cols){
      fit = fit + model_dt[[col]][1]*data_dt[[col]]
    }
    fit
  }
  # Helper: n normal draws with the chosen model's estimated scale
  draw_epsilon = function(model_dt, n){
    if(nrow(model_dt) == 0) return(rep(NA_real_, n))
    rnorm(n, mean = 0, sd = exp(model_dt$log_scale[1]))
  }
  # Helper: censoring simulated generation to [0, cap]
  censor_to_capacity = function(x, cap){
    fcase(
      x >= 0 & x <= cap, x,
      x < 0, 0,
      x > cap, cap
    )
  }

  # Reading data for a single unit
  # nomatch = NULL: a unit without data gives zero rows, not one all-NA row
  unit_gen_dt = elec_gen_dt[.(orispl_code_in, unitid_in), nomatch = NULL]
  if(nrow(unit_gen_dt) == 0){
    warning(
      'No generation data for unit ', orispl_code_in, ' ', unitid_in,
      '; skipping',
      call. = FALSE
    )
    return(invisible(NULL))
  }
  unit_region = unit_gen_dt$nerc_adj[1]
  # Output files are named <prefix>-<orispl_code>-<clean_unitid>.fst
  file_tag = paste0(orispl_code_in, "-", unit_gen_dt$clean_unitid[1])
  out_path = function(folder, prefix){
    here(paste0(
      "Data/electricity-generation/unit-model-fit/",
      folder, "/", prefix, "-", file_tag, ".fst"
    ))
  }

  # Running models for all of the formulas 
  unit_mod_coef_dt = 
    map_dfr(
      fml_dt[nerc_adj == unit_region]$ic_formula,
      \(fml){
        # Fitting the model
        unit_mod = 
          AER::tobit(
            formula = as.formula(fml), 
            left = 0,
            right = unit_gen_dt$namepcap[1],
            data = unit_gen_dt
          )
        # Collecting results: one row per coefficient plus the log scale
        data.table(
          rn = c(names(unit_mod$coefficients),'log_scale'),
          estimate = c(unit_mod$coefficients, log(unit_mod$scale)),
          orispl_code = orispl_code_in, 
          unitid = unitid_in,
          fml = fml, 
          loglik = unit_mod$loglik[2],
          loglik_df = unit_mod$df
        )
      }
    )
  # Saving the raw model results 
  write.fst(
    x = unit_mod_coef_dt,
    path = out_path("coefs", "unit-mod-coefs")
  )

  # Determining the preferred model
  # - All excess-load coefficients must be non-negative 
  # - Must be increasing over range of observed loads in own region  
  # - Picking model with highest log-likelihood conditional on above
  # Casting to wide, now have one row for every model 
  # (terms that are not in a model get a coefficient of 0)
  coef_dt = 
    unit_mod_coef_dt |>
    dcast(
      orispl_code + unitid + fml + loglik + loglik_df ~ rn,
      value.var = "estimate",
      fill = 0
    ) |>
    setnames(
      old = '(Intercept)',
      new = 'intercept'
    )
  coef_dt[, has_sq_term := str_detect(fml, 'sq')]
  # Adding missing columns (excess load outside of interconnection)
  absent_cols = setdiff(c(linear_cols, square_cols), colnames(coef_dt))
  if(length(absent_cols) > 0) coef_dt[, (absent_cols) := 0]

  # Merging with unit info to get nerc_adj and the max excess load for the
  # unit's region
  all_model_dt =
    merge(
      coef_dt, 
      unit_info_dt,
      by = c('orispl_code','unitid')
    ) |>
    merge(
      max_nerc_load_dt,
      by = c("nerc_adj")
    )

  # Picking, for every row, the coefficients on the unit's own region
  # (rows whose nerc_adj matches no region get NA, so they are never 'good')
  own_pick = cbind(
    seq_len(nrow(all_model_dt)),
    match(all_model_dt$nerc_adj, toupper(regions))
  )
  own_slope = as.matrix(all_model_dt[, linear_cols, with = FALSE])[own_pick]
  own_curve = as.matrix(all_model_dt[, square_cols, with = FALSE])[own_pick]
  # All linear excess-load coefficients (own region and controls) >= 0
  all_linear_ok = Reduce(
    `&`,
    lapply(linear_cols, \(col) all_model_dt[[col]] >= 0)
  )
  # Checking that derivative wrt own region excess load is
  # non-negative at 0 and at max excess load.
  # With fitted value b*x + c*x^2 the derivative is b + 2*c*x, which matches
  # the -b/(2c) extremum computed from these coefficients later in the script.
  all_model_dt[,':='(
    increasing_at_0 = own_slope >= 0,
    increasing_at_max = own_slope + 2*own_curve*max_eload >= 0,
    coefs_positive = all_linear_ok
  )][,# Creating indicator for models that fit our restrictions 
    good_model := coefs_positive & increasing_at_0 & increasing_at_max
  ]

  # Choosing the best (highest log likelihood) model among the good ones
  # (Excludes units that have **NO** good models)
  good_model_dt = all_model_dt[
    good_model == TRUE, 
    .SD[which.max(loglik)], 
    by = .(orispl_code, unitid)
  ]
  # Saving as wide table with variables in regression as columns
  write.fst(
    x = good_model_dt,
    path = out_path("good-model-dt", "good-model-dt")
  )
  # Repeating for a model with just own region load 
  own_region_model_dt = merge(
    all_model_dt,
    fml_dt[,.(ic_formula, own_region_only)] |> unique(),
    by.x = 'fml', by.y = 'ic_formula'
  )[own_region_only == TRUE & good_model == TRUE,
    .SD[which.max(loglik)], 
    by = .(orispl_code, unitid)
  ]
  # Saving as wide table with variables in regression as columns
  write.fst(
    x = own_region_model_dt,
    path = out_path("own-region-model-dt", "own-region-model-dt")
  )

  if(nrow(good_model_dt) == 0 || nrow(own_region_model_dt) == 0){
    warning(
      'No admissible model for unit ', orispl_code_in, ' ', unitid_in,
      '; the corresponding fitted values are set to NA',
      call. = FALSE
    )
  }
  # Calculating raw fitted values
  unit_gen_dt[,':='(
    fit_gload_mwh_raw = linear_predictor(good_model_dt, unit_gen_dt),
    fit_gload_mwh_raw_own_region = 
      linear_predictor(own_region_model_dt, unit_gen_dt)
  )]
  # Drawing epsilons (overall model first, then own-region model)
  N = nrow(unit_gen_dt)
  unit_gen_dt[,':='(
    epsilon = draw_epsilon(good_model_dt, N),
    epsilon_own_region = draw_epsilon(own_region_model_dt, N)
  )]
  # Censoring the fitted values
  unit_gen_dt[,':='(
    fit_gload_mwh = censor_to_capacity(
      fit_gload_mwh_raw + epsilon, 
      namepcap
    ),
    fit_gload_mwh_own_region = censor_to_capacity(
      fit_gload_mwh_raw_own_region + epsilon_own_region, 
      namepcap
    )
  )]
  # Saving the results
  write.fst(
    x = unit_gen_dt[,.(
      orispl_code, unitid, utc_time, 
      gload_mwh, fit_gload_mwh,
      op_time, gload_mw, 
      so2_mass_lbs, nox_mass_lbs, co2_mass_tons, 
      heat_input_mm_btu,sload_1000lb_hr,
      fit_gload_mwh_raw, epsilon,
      fit_gload_mwh_own_region, fit_gload_mwh_raw_own_region, epsilon_own_region
    )],
    path = out_path("unit-gen-dt", "unit-gen-dt")
  )
}

# Deleting old files 
# Every folder fit_unit() writes into is emptied and recreated. The names
# must match the paths used inside fit_unit(): the own-region results go to
# "own-region-model-dt".
# NOTE(review): because coefs/ is emptied here, the "not yet run" check that
# follows treats every unit as not run; skip this section to resume a
# partially completed run.
fit_out_dirs = file.path(
  here("Data/electricity-generation/unit-model-fit"),
  c("coefs", "unit-gen-dt", "good-model-dt", "own-region-model-dt")
)
unlink(fit_out_dirs, recursive = TRUE)
# recursive = TRUE also creates unit-model-fit/ itself when it is missing
walk(fit_out_dirs, dir.create, recursive = TRUE)

# Getting list of units not yet run 
# Coefficient files are named unit-mod-coefs-<orispl_code>-<clean_unitid>.fst,
# so the units already run are recovered from the file names
coef_file_stems = 
  list.files(here("Data/electricity-generation/unit-model-fit/coefs")) |> 
  str_remove("unit-mod-coefs-")
unit_dt_run = 
  data.table(
    orispl_code = 
      coef_file_stems |>
      str_extract("\\w*(?=-)") |>
      as.integer(),
    clean_unitid = 
      coef_file_stems |>
      str_extract("(?<=-)\\w*(?=\\.fst)"),
    # Sized explicitly so that an empty folder gives a zero-row table
    already_run = rep(TRUE, length(coef_file_stems))
  )

# Units in unit_info_dt without a coefficient file
unit_dt_not_run = 
  merge(
    unit_info_dt, 
    unit_dt_run, 
    by = c("orispl_code","clean_unitid"),
    all.x = TRUE
  )[is.na(already_run) 
    ,.(orispl_code, unitid)
  ]

# Running for all units 
# Columns are renamed to fit_unit()'s argument names so that the call does not
# depend on partial argument matching. pwalk() is used because fit_unit() is
# called for the files it writes; its return values are not collected.
pwalk(
  unit_dt_not_run[, .(orispl_code_in = orispl_code, unitid_in = unitid)],
  fit_unit
)
invisible(gc())

# Combining all of the model files 
# Each per-unit coefficient file is read and the tables are stacked by name
unit_mod_coef_dt = 
  here("Data/electricity-generation/unit-model-fit/coefs") |>
  list.files(full.names = TRUE) |>
  map(\(coef_file) read.fst(coef_file, as.data.table = TRUE)) |>
  bind_rows()
write.fst(
  x = unit_mod_coef_dt,
  path = here("Data/electricity-generation/unit-model-fit/unit-mod-coef-dt.fst")
)

# Combining all of the good model files 
good_model_dt = 
  here("Data/electricity-generation/unit-model-fit/good-model-dt") |>
  list.files(full.names = TRUE) |>
  map(\(model_file) read.fst(model_file, as.data.table = TRUE)) |>
  bind_rows()

# For models with a squared term: the excess load at which the fitted
# quadratic b*x + c*x^2 has its extremum, -b/(2c). Models whose formula has no
# squared term get NA.
good_model_dt[, est_extremum := fcase(
  grepl('cal_sq', fml, fixed = TRUE), 
    -excess_load_cal/(2*excess_load_cal_sq),
  grepl('mro_sq', fml, fixed = TRUE), 
    -excess_load_mro/(2*excess_load_mro_sq),
  grepl('npcc_sq', fml, fixed = TRUE), 
    -excess_load_npcc/(2*excess_load_npcc_sq),
  grepl('rfc_sq', fml, fixed = TRUE), 
    -excess_load_rfc/(2*excess_load_rfc_sq),
  grepl('serc_sq', fml, fixed = TRUE), 
    -excess_load_serc/(2*excess_load_serc_sq),
  grepl('tre_sq', fml, fixed = TRUE), 
    -excess_load_tre/(2*excess_load_tre_sq),
  grepl('wecc_sq', fml, fixed = TRUE), 
    -excess_load_wecc/(2*excess_load_wecc_sq)
)]

write.fst(
  x = good_model_dt,
  path = here("Data/electricity-generation/unit-model-fit/good-model-dt.fst")
)


# Dropping the large tables that are no longer needed, then collecting garbage
rm(list = c(
  "elec_gen_dt", "unit_info_dt",
  "unit_dt_run", "unit_dt_not_run",
  "fml_dt", "fml_dt_raw",
  "good_model_dt", "max_nerc_load_dt"
))
invisible(gc())

# Hourly emissions data
# Stacking the per-unit hourly files, keying the result, and saving it as a
# single file
elec_gen_dt = 
  here("Data/electricity-generation/unit-model-fit/unit-gen-dt") |>
  list.files(full.names = TRUE) |>
  map(\(gen_file) read.fst(gen_file)) |>
  bind_rows() |>
  data.table()
setkeyv(elec_gen_dt, c("orispl_code", "unitid", "utc_time"))
write.fst(
  x = elec_gen_dt,
  path = here("Data/electricity-generation/unit-model-fit/elec-gen-fit-dt.fst")
)


